{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AVv_M1Dz9TDz"
      },
      "source": [
        "# 通过引入语义缓存到 FAISS 中以增强 RAG 系统的性能\n",
        "\n",
        "_作者:[Pere Martra](https://github.com/peremartra)_\n",
        "\n",
        "在这个 notebook 中，我们将使用一个现成的模型和 Chroma 数据库来搭建一个常见的 RAG 系统。**但我们会加入一个新功能，就是一个语义缓存系统，它会保存用户的各种问题，并决定是直接用数据库的信息来回答问题，还是用之前保存的问题答案。**\n",
        "\n",
        "这个语义缓存系统的目的是找出用户提出的问题中哪些是相似的或者是一样的。如果找到了一个之前问过的问题，系统就会直接用缓存里的答案来回答，这样就不用再去数据库里找了。\n",
        "\n",
        "因为这个系统会考虑问题的实际意思，所以即使问题表达的方式不同，或者有些小错误，比如拼写或句子结构不对，系统也能识别出用户其实是在问同一个问题。\n",
        "\n",
        "比如，像 **法国的首都是什么？**、**告诉我法国的首都叫什么？** 和 **法国的首都是什么？** 这样的问题，虽然问法不一样，但都是在问同一个事情。\n",
        "\n",
        "虽然根据问题的不同，模型的回答可能会有点不一样，但基本上从数据库里拿到的信息应该是相同的。这就是为什么我们把缓存系统放在用户和数据库之间，而不是用户和语言模型之间。\n",
        "\n",
        "\n",
        "\n",
        "<img src=\"https://huggingface.co/datasets/huggingface/cookbook-images/resolve/main/semantic_cache.jpg\">\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5gtBERjX1vFd"
      },
      "source": [
        "大多数教程指导你创建一个 RAG 系统，这些教程都是为单个用户设计的，用于在测试环境中运行。换句话说，就是在笔记本中与本地向量数据库交互，以及进行 API 调用或使用本地存储的模型。\n",
        "\n",
        "当尝试将其中一种模型过渡到生产环境时，这种架构很快就显得不够用了，在生产环境中，它们可能会遇到从几十到成千上万次的重复请求。\n",
        "\n",
        "提高性能的一种方法是通过一个或多个语义缓存。这个缓存保留了以前请求的结果，并且在解决新请求之前，它会检查是否之前收到过类似的请求。如果是这样，它就不会重新执行过程，而是从缓存中检索信息。\n",
        "\n",
        "在 RAG 系统中，有两个耗时的点：\n",
        "\n",
        "* 检索用于构建丰富提示的信息：\n",
        "* 调用大型语言模型以获得响应。\n",
        "\n",
        "在这两点上，都可以实现语义缓存系统，我们甚至可以有两个缓存，每个点一个。\n",
        "\n",
        "将缓存系统放在模型的响应点可能会导致对获得响应的影响减少。我们的缓存系统可能会将\"用 10 个词解释法国大革命\"和\"用 100 个词解释法国大革命\"视为相同的查询。如果我们的缓存系统存储模型响应，用户可能会认为他们的指令没有被准确地遵循。\n",
        "\n",
        "但是，两个请求都需要相同的信息来丰富提示。这就是我选择将语义缓存系统放置在用户请求和从向量数据库检索信息之间的主要原因。\n",
        "\n",
        "然而，这是一个设计决策。根据响应类型和系统请求的不同，它可以被放置在一个点或另一个点。很明显，缓存模型响应会节省最多的时间，但正如我已经解释过的，这样做会牺牲用户对响应的影响。\n",
        "\n",
        "\n",
        "\n"
      ]
    },
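    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make this placement concrete, here is a minimal sketch of a cache that sits in front of the retrieval step. It is not the implementation used in this notebook: `toy_embed` is a hypothetical bag-of-words stand-in for a real sentence-embedding model, `fake_retrieve` stands in for the vector database query, and the `0.6` threshold is an arbitrary value chosen for the toy data.\n",
        "\n",
        "```python\n",
        "import math\n",
        "import re\n",
        "from collections import Counter\n",
        "\n",
        "\n",
        "def toy_embed(text):\n",
        "    # Hypothetical stand-in for an embedding model: word counts as a sparse vector.\n",
        "    return Counter(re.findall(r'[a-z0-9]+', text.lower()))\n",
        "\n",
        "\n",
        "def cosine_similarity(a, b):\n",
        "    # Cosine similarity between two sparse vectors stored as Counters.\n",
        "    dot = sum(a[token] * b[token] for token in a if token in b)\n",
        "    norm_a = math.sqrt(sum(v * v for v in a.values()))\n",
        "    norm_b = math.sqrt(sum(v * v for v in b.values()))\n",
        "    if norm_a == 0 or norm_b == 0:\n",
        "        return 0.0\n",
        "    return dot / (norm_a * norm_b)\n",
        "\n",
        "\n",
        "class ToySemanticCache:\n",
        "    # Sits between the user question and the slow retrieval step.\n",
        "    def __init__(self, retrieve_fn, threshold=0.6):\n",
        "        self.retrieve_fn = retrieve_fn\n",
        "        self.threshold = threshold\n",
        "        self.entries = []  # list of (question_vector, retrieved_documents)\n",
        "\n",
        "    def ask(self, question):\n",
        "        query_vec = toy_embed(question)\n",
        "        best_score, best_docs = 0.0, None\n",
        "        for vec, docs in self.entries:\n",
        "            score = cosine_similarity(query_vec, vec)\n",
        "            if score > best_score:\n",
        "                best_score, best_docs = score, docs\n",
        "        if best_docs is not None and best_score >= self.threshold:\n",
        "            return best_docs, 'cache'  # similar question seen before\n",
        "        docs = self.retrieve_fn(question)  # cache miss: query the database\n",
        "        self.entries.append((query_vec, docs))\n",
        "        return docs, 'database'\n",
        "\n",
        "\n",
        "calls = []  # records every question that reaches the fake database\n",
        "\n",
        "\n",
        "def fake_retrieve(question):\n",
        "    calls.append(question)\n",
        "    return [f'documents for: {question}']\n",
        "\n",
        "\n",
        "cache = ToySemanticCache(fake_retrieve, threshold=0.6)\n",
        "first = cache.ask('What is the capital of France?')\n",
        "second = cache.ask('what is the capital of france')\n",
        "third = cache.ask('How is diabetes treated?')\n",
        "print(first[1], second[1], third[1])\n",
        "```\n",
        "\n",
        "In this toy run the second question differs from the first only in casing and punctuation, so it is served from the cache, while the unrelated third question goes to the database. With a real embedding model, paraphrases that use different words could match as well."
      ]
    },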
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uizxY8679TDz"
      },
      "source": [
        "# 导入并加载库。\n",
        "首先，我们需要安装必要的 Python 包。\n",
        "* **[sentence transformers](http:/www.sbert.net/)**。这个库用于将句子转换为固定长度的向量，也称为嵌入。\n",
        "* **[xformers](https://github.com/facebookresearch/xformers)**。这是一个提供库和工具的包，以便与 transformers 模型一起使用。我们需要安装它，以避免在处理模型和嵌入时出现错误。\n",
        "* **[chromadb](https://www.trychroma.com/)**。这是我们的向量数据库。ChromaDB 易于使用且开源，可能是用于存储嵌入的最常用的向量数据库。\n",
        "* **[accelerate](https://github.com/huggingface/accelerate)**。在 GPU 上运行模型的必要条件。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2024-02-29T17:30:10.787688Z",
          "iopub.status.busy": "2024-02-29T17:30:10.787382Z",
          "iopub.status.idle": "2024-02-29T17:34:12.804579Z",
          "shell.execute_reply": "2024-02-29T17:34:12.80338Z",
          "shell.execute_reply.started": "2024-02-29T17:30:10.787657Z"
        },
        "id": "r1nUzd1u9TD0",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "!pip install -q transformers==4.38.1\n",
        "!pip install -q accelerate==0.27.2\n",
        "!pip install -q sentence-transformers==2.5.1\n",
        "!pip install -q xformers==0.0.24\n",
        "!pip install -q chromadb==0.4.24\n",
        "!pip install -q datasets==2.17.1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2024-02-29T17:35:23.197598Z",
          "iopub.status.busy": "2024-02-29T17:35:23.197205Z",
          "iopub.status.idle": "2024-02-29T17:35:23.202259Z",
          "shell.execute_reply": "2024-02-29T17:35:23.201404Z",
          "shell.execute_reply.started": "2024-02-29T17:35:23.197556Z"
        },
        "id": "5jUwC_eE9TD0",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9P-kYtc79TD1"
      },
      "source": [
        "# 加载数据集\n",
        "\n",
        "由于我们在一个免费且有限的空间中工作，并且只能使用几 GB 的内存，我通过变量 `MAX_ROWS` 限制了从数据集中使用的行数。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xZsN8yzUvfjN"
      },
      "outputs": [],
      "source": [
        "#Login to Hugging Face. It is mandatory to use the Gemma Model,\n",
        "#and recommended to acces public models and Datasets.\n",
        "from getpass import getpass\n",
        "if 'hf_key' not in locals():\n",
        "  hf_key = getpass(\"Your Hugging Face API Key: \")\n",
        "!huggingface-cli login --token $hf_key"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 47,
      "metadata": {
        "id": "9IVxu-uxtCTw"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "data = load_dataset(\"keivalya/MedQuad-MedicalQnADataset\", split='train')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hmor-i1j9TD1"
      },
      "source": [
        "ChromaDB 要求数据具有唯一的标识符。我们可以使用这个语句来创建一个名为**Id**的新列。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 48,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 536
        },
        "id": "WbLf8c7_yHwy",
        "outputId": "492eac81-2f7b-4063-f444-405bf489d08e"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "summary": "{\n  \"name\": \"data\",\n  \"rows\": 16407,\n  \"fields\": [\n    {\n      \"column\": \"qtype\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 16,\n        \"samples\": [\n          \"susceptibility\",\n          \"symptoms\",\n          \"information\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"Question\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 14979,\n        \"samples\": [\n          \"What are the symptoms of Danon disease ?\",\n          \"What is (are) Dowling-Degos disease ?\",\n          \"What are the genetic changes related to Pearson marrow-pancreas syndrome ?\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"Answer\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 15817,\n        \"samples\": [\n          \"These resources address the diagnosis or management of glycogen storage disease type III:  - Gene Review: Gene Review: Glycogen Storage Disease Type III  - Genetic Testing Registry: Glycogen storage disease type III   These resources from MedlinePlus offer information about the diagnosis and management of various health conditions:  - Diagnostic Tests  - Drug Therapy  - Surgery and Rehabilitation  - Genetic Counseling   - Palliative Care\",\n          \"Diagnostic Challenges\\n  \\nFor doctors, diagnosing chronic fatigue syndrome (CFS) can be complicated by a number of factors:\\n  \\n   - There's no lab test or biomarker for CFS.\\n   - Fatigue and other symptoms of CFS are common to many illnesses.\\n   - For some CFS patients, it may not be obvious to doctors that they are ill.\\n   - The illness has a pattern of remission and relapse.\\n   - Symptoms vary from person to person in type, number, and severity.\\n  \\n  \\nThese factors 
have contributed to a low diagnosis rate. Of the one to four million Americans who have CFS, less than 20% have been diagnosed.\\n  Exams and Screening Tests for CFS\\n  \\nBecause there is no blood test, brain scan, or other lab test to diagnose CFS, the doctor should first rule out other possible causes.\\n  \\nIf a patient has had 6 or more consecutive months of severe fatigue that is reported to be unrelieved by sufficient bed rest and that is accompanied by nonspecific symptoms, including flu-like symptoms, generalized pain, and memory problems, the doctor should consider the possibility that the patient may have CFS. Further exams and tests are needed before a diagnosis can be made:\\n  \\n   - A detailed medical history will be needed and should include a review of medications that could be causing the fatigue and symptoms\\n   - A thorough physical and mental status examination will also be needed\\n   - A battery of laboratory screening tests will be needed to help identify or rule out other possible causes of the symptoms that could be treated\\n   - The doctor may also order additional tests to follow up on results of the initial screening tests\\n  \\n  \\nA CFS diagnosis requires that the patient has been fatigued for 6 months or more and has 4 of the 8 symptoms for CFS for 6 months or more. 
If, however, the patient has been fatigued for 6 months or more but does not have four of the eight symptoms, the diagnosis may be idiopathic fatigue.\\n  \\nThe complete process for diagnosing CFS can be found here.\\n  \\nAdditional information for healthcare professionals on use of tests can be found here.\",\n          \"Eating, diet, and nutrition have not been shown to play a role in causing or preventing simple kidney cysts.\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"id\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 4736,\n        \"min\": 0,\n        \"max\": 16406,\n        \"num_unique_values\": 16407,\n        \"samples\": [\n          3634,\n          15104,\n          4395\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}",
              "type": "dataframe",
              "variable_name": "data"
            },
            "text/html": [
              "\n",
              "  <div id=\"df-e3cca7df-77db-4037-bb3f-d65b3ff8cbb0\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>qtype</th>\n",
              "      <th>Question</th>\n",
              "      <th>Answer</th>\n",
              "      <th>id</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>susceptibility</td>\n",
              "      <td>Who is at risk for Lymphocytic Choriomeningiti...</td>\n",
              "      <td>LCMV infections can occur after exposure to fr...</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>symptoms</td>\n",
              "      <td>What are the symptoms of Lymphocytic Choriomen...</td>\n",
              "      <td>LCMV is most commonly recognized as causing ne...</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>susceptibility</td>\n",
              "      <td>Who is at risk for Lymphocytic Choriomeningiti...</td>\n",
              "      <td>Individuals of all ages who come into contact ...</td>\n",
              "      <td>2</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>exams and tests</td>\n",
              "      <td>How to diagnose Lymphocytic Choriomeningitis (...</td>\n",
              "      <td>During the first phase of the disease, the mos...</td>\n",
              "      <td>3</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>treatment</td>\n",
              "      <td>What are the treatments for Lymphocytic Chorio...</td>\n",
              "      <td>Aseptic meningitis, encephalitis, or meningoen...</td>\n",
              "      <td>4</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>prevention</td>\n",
              "      <td>How to prevent Lymphocytic Choriomeningitis (L...</td>\n",
              "      <td>LCMV infection can be prevented by avoiding co...</td>\n",
              "      <td>5</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>information</td>\n",
              "      <td>What is (are) Parasites - Cysticercosis ?</td>\n",
              "      <td>Cysticercosis is an infection caused by the la...</td>\n",
              "      <td>6</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>susceptibility</td>\n",
              "      <td>Who is at risk for Parasites - Cysticercosis? ?</td>\n",
              "      <td>Cysticercosis is an infection caused by the la...</td>\n",
              "      <td>7</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>exams and tests</td>\n",
              "      <td>How to diagnose Parasites - Cysticercosis ?</td>\n",
              "      <td>If you think that you may have cysticercosis, ...</td>\n",
              "      <td>8</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9</th>\n",
              "      <td>treatment</td>\n",
              "      <td>What are the treatments for Parasites - Cystic...</td>\n",
              "      <td>Some people with cysticercosis do not need to ...</td>\n",
              "      <td>9</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-e3cca7df-77db-4037-bb3f-d65b3ff8cbb0')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-e3cca7df-77db-4037-bb3f-d65b3ff8cbb0 button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-e3cca7df-77db-4037-bb3f-d65b3ff8cbb0');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "<div id=\"df-8d88a5c2-4d94-419e-a3de-0292c6501384\">\n",
              "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-8d88a5c2-4d94-419e-a3de-0292c6501384')\"\n",
              "            title=\"Suggest charts\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "  </button>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "      --bg-color: #E8F0FE;\n",
              "      --fill-color: #1967D2;\n",
              "      --hover-bg-color: #E2EBFA;\n",
              "      --hover-fill-color: #174EA6;\n",
              "      --disabled-fill-color: #AAA;\n",
              "      --disabled-bg-color: #DDD;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "      --bg-color: #3B4455;\n",
              "      --fill-color: #D2E3FC;\n",
              "      --hover-bg-color: #434B5C;\n",
              "      --hover-fill-color: #FFFFFF;\n",
              "      --disabled-bg-color: #3B4455;\n",
              "      --disabled-fill-color: #666;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart {\n",
              "    background-color: var(--bg-color);\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: var(--fill-color);\n",
              "    height: 32px;\n",
              "    padding: 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: var(--hover-bg-color);\n",
              "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: var(--button-hover-fill-color);\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart-complete:disabled,\n",
              "  .colab-df-quickchart-complete:disabled:hover {\n",
              "    background-color: var(--disabled-bg-color);\n",
              "    fill: var(--disabled-fill-color);\n",
              "    box-shadow: none;\n",
              "  }\n",
              "\n",
              "  .colab-df-spinner {\n",
              "    border: 2px solid var(--fill-color);\n",
              "    border-color: transparent;\n",
              "    border-bottom-color: var(--fill-color);\n",
              "    animation:\n",
              "      spin 1s steps(1) infinite;\n",
              "  }\n",
              "\n",
              "  @keyframes spin {\n",
              "    0% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "      border-left-color: var(--fill-color);\n",
              "    }\n",
              "    20% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    30% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    40% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    60% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    80% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "    90% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "  }\n",
              "</style>\n",
              "\n",
              "  <script>\n",
              "    async function quickchart(key) {\n",
              "      const quickchartButtonEl =\n",
              "        document.querySelector('#' + key + ' button');\n",
              "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
              "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
              "      try {\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      } catch (error) {\n",
              "        console.error('Error during call to suggestCharts:', error);\n",
              "      }\n",
              "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
              "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
              "    }\n",
              "    (() => {\n",
              "      let quickchartButtonEl =\n",
              "        document.querySelector('#df-8d88a5c2-4d94-419e-a3de-0292c6501384 button');\n",
              "      quickchartButtonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "    })();\n",
              "  </script>\n",
              "</div>\n",
              "    </div>\n",
              "  </div>\n"
            ],
            "text/plain": [
              "             qtype                                           Question  \\\n",
              "0   susceptibility  Who is at risk for Lymphocytic Choriomeningiti...   \n",
              "1         symptoms  What are the symptoms of Lymphocytic Choriomen...   \n",
              "2   susceptibility  Who is at risk for Lymphocytic Choriomeningiti...   \n",
              "3  exams and tests  How to diagnose Lymphocytic Choriomeningitis (...   \n",
              "4        treatment  What are the treatments for Lymphocytic Chorio...   \n",
              "5       prevention  How to prevent Lymphocytic Choriomeningitis (L...   \n",
              "6      information          What is (are) Parasites - Cysticercosis ?   \n",
              "7   susceptibility    Who is at risk for Parasites - Cysticercosis? ?   \n",
              "8  exams and tests        How to diagnose Parasites - Cysticercosis ?   \n",
              "9        treatment  What are the treatments for Parasites - Cystic...   \n",
              "\n",
              "                                              Answer  id  \n",
              "0  LCMV infections can occur after exposure to fr...   0  \n",
              "1  LCMV is most commonly recognized as causing ne...   1  \n",
              "2  Individuals of all ages who come into contact ...   2  \n",
              "3  During the first phase of the disease, the mos...   3  \n",
              "4  Aseptic meningitis, encephalitis, or meningoen...   4  \n",
              "5  LCMV infection can be prevented by avoiding co...   5  \n",
              "6  Cysticercosis is an infection caused by the la...   6  \n",
              "7  Cysticercosis is an infection caused by the la...   7  \n",
              "8  If you think that you may have cysticercosis, ...   8  \n",
              "9  Some people with cysticercosis do not need to ...   9  "
            ]
          },
          "execution_count": 48,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "data = data.to_pandas()\n",
        "data[\"id\"]=data.index\n",
        "data.head(10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2024-02-29T17:35:25.528374Z",
          "iopub.status.busy": "2024-02-29T17:35:25.527688Z",
          "iopub.status.idle": "2024-02-29T17:35:25.709895Z",
          "shell.execute_reply": "2024-02-29T17:35:25.709127Z",
          "shell.execute_reply.started": "2024-02-29T17:35:25.528341Z"
        },
        "id": "DZf0zCI29TD1",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "MAX_ROWS = 15000\n",
        "DOCUMENT=\"Answer\"\n",
        "TOPIC=\"qtype\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2024-02-29T17:35:29.184342Z",
          "iopub.status.busy": "2024-02-29T17:35:29.183979Z",
          "iopub.status.idle": "2024-02-29T17:35:29.189229Z",
          "shell.execute_reply": "2024-02-29T17:35:29.1881Z",
          "shell.execute_reply.started": "2024-02-29T17:35:29.184313Z"
        },
        "id": "Mkoj9IrZ9TD1",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "#Because it is just a sample we select a small portion of News.\n",
        "subset_data = data.head(MAX_ROWS)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rZHg_Qh69TD1"
      },
      "source": [
        "# 导入并配置向量数据库\n",
        "\n",
        "为了存储信息，我选择使用 ChromaDB，这是最知名且广泛使用的开源向量数据库之一。\n",
        "\n",
        "首先我们需要导入 ChromaDB。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2024-02-29T17:35:31.849551Z",
          "iopub.status.busy": "2024-02-29T17:35:31.849199Z",
          "iopub.status.idle": "2024-02-29T17:35:32.31736Z",
          "shell.execute_reply": "2024-02-29T17:35:32.316617Z",
          "shell.execute_reply.started": "2024-02-29T17:35:31.849525Z"
        },
        "id": "npJhuZQw9TD1",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "import chromadb"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8okox5C89TD1"
      },
      "source": [
        "现在我们只需要指定存储向量数据库的路径。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2024-02-29T17:35:34.410646Z",
          "iopub.status.busy": "2024-02-29T17:35:34.410268Z",
          "iopub.status.idle": "2024-02-29T17:35:34.872817Z",
          "shell.execute_reply": "2024-02-29T17:35:34.872039Z",
          "shell.execute_reply.started": "2024-02-29T17:35:34.410614Z"
        },
        "id": "9yK6y0hm9TD1",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "chroma_client = chromadb.PersistentClient(path=\"/path/to/persist/directory\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7MhMwk3J9TD1"
      },
      "source": [
        "# 填充和查询 ChromaDB 数据库\n",
        "\n",
        "ChromaDB 中的数据存储在集合中。如果集合已存在，我们需要删除它。\n",
        "在接下来的行中，我们通过调用上面创建的 `chroma_client` 中的 `create_collection` 函数来创建集合。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2024-02-29T17:35:36.116012Z",
          "iopub.status.busy": "2024-02-29T17:35:36.1156Z",
          "iopub.status.idle": "2024-02-29T17:35:36.16922Z",
          "shell.execute_reply": "2024-02-29T17:35:36.168504Z",
          "shell.execute_reply.started": "2024-02-29T17:35:36.115977Z"
        },
        "id": "kRCsunE19TD1",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "collection_name = \"news_collection\"\n",
        "if len(chroma_client.list_collections()) > 0 and collection_name in [chroma_client.list_collections()[0].name]:\n",
        "    chroma_client.delete_collection(name=collection_name)\n",
        "\n",
        "collection = chroma_client.create_collection(name=collection_name)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rdEtcETr9TD2"
      },
      "source": [
        "现在我们准备好使用 `add` 函数将数据添加到集合中。这个函数需要三个关键信息：\n",
        "\n",
        "* 在 **文档** 中，我们存储数据集中 `Answer` 列的内容。\n",
        "* 在 **元数据** 中，我们可以提供一个主题列表。我使用了 `qtype` 列中的值。\n",
        "* 在 **id** 中，我们需要为每一行提供一个唯一的标识符。我使用 `MAX_ROWS` 的范围来创建ID。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "execution": {
          "iopub.execute_input": "2024-02-29T17:35:38.051601Z",
          "iopub.status.busy": "2024-02-29T17:35:38.051179Z",
          "iopub.status.idle": "2024-02-29T17:36:38.612836Z",
          "shell.execute_reply": "2024-02-29T17:36:38.611814Z",
          "shell.execute_reply.started": "2024-02-29T17:35:38.051569Z"
        },
        "id": "4dDoqJE79TD2",
        "outputId": "36f579dc-ec60-48b1-807a-1e68113cc9f4",
        "trusted": true
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx.tar.gz: 100%|██████████| 79.3M/79.3M [00:01<00:00, 68.1MiB/s]\n"
          ]
        }
      ],
      "source": [
        "collection.add(\n",
        "    documents=subset_data[DOCUMENT].tolist(),\n",
        "    metadatas=[{TOPIC: topic} for topic in subset_data[TOPIC].tolist()],\n",
        "    ids=[f\"id{x}\" for x in range(MAX_ROWS)],\n",
        ")"
      ]
    },
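    {
      "cell_type": "markdown",
      "metadata": {
        "id": "du6-iuUisRkM"
      },
      "source": [
        "Once the information is in the database, we can query it and ask for the data that matches our needs. The search runs over the content of the documents and does not look for exact words or phrases. The results are based on the similarity between the search terms and the document content.\n",
        "\n",
        "Metadata is not used when computing similarity, but it can be used to filter or refine the results, for example through the `where` argument of `query`, allowing further customization and precision.\n",
        "\n",
        "Let's define a function to query the ChromaDB database.\n"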
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2024-02-29T17:36:38.616047Z",
          "iopub.status.busy": "2024-02-29T17:36:38.615302Z",
          "iopub.status.idle": "2024-02-29T17:36:38.620516Z",
          "shell.execute_reply": "2024-02-29T17:36:38.619561Z",
          "shell.execute_reply.started": "2024-02-29T17:36:38.616008Z"
        },
        "id": "UjdhZ4MJ9TD2",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "def query_database(query_text, n_results=10):\n",
        "    results = collection.query(query_texts=query_text, n_results=n_results )\n",
        "    return results"
      ]
    },
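    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CL0Crl3x9TD2"
      },
      "source": [
        "## Creating the semantic cache system\n",
        "To implement the cache system, we will use Faiss, a library that stores embeddings in memory. This is quite similar to what Chroma does, but without its persistence.\n",
        "\n",
        "For this purpose, we will create a class called `semantic_cache` that works with its own encoder and provides the functions the user needs to run queries.\n",
        "\n",
        "In this class, we first query the cache implemented with Faiss, which holds the previous requests. If the distance to the closest cached question is within a specified threshold, the cached content is returned. Otherwise, the result is fetched from the Chroma database.\n",
        "The cache is stored in a .json file.\n"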
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2024-02-29T17:36:38.621968Z",
          "iopub.status.busy": "2024-02-29T17:36:38.621655Z",
          "iopub.status.idle": "2024-02-29T17:36:51.313356Z",
          "shell.execute_reply": "2024-02-29T17:36:51.312232Z",
          "shell.execute_reply.started": "2024-02-29T17:36:38.621936Z"
        },
        "id": "6OzUbRUe9TD2",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "!pip install -q faiss-cpu==1.8.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "0yGE4cTEp3QJ"
      },
      "outputs": [],
      "source": [
        "import faiss\n",
        "from sentence_transformers import SentenceTransformer\n",
        "import time\n",
        "import json"
      ]
    },
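    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yi_riXHhcLy0"
      },
      "source": [
        "The `init_cache()` function below initializes the semantic cache.\n",
        "\n",
        "It uses the FlatL2 index, which may not be the fastest but is ideal for small datasets. Depending on the characteristics of the data to be cached and the expected size of the cache, another index such as HNSW or IVF could be used instead.\n",
        "\n",
        "I chose this index because it fits the example well. It can be used with high-dimensional vectors, consumes little memory, and performs well on small datasets.\n",
        "\n",
        "The key features of the main indexes available in Faiss are outlined below.\n",
        "\n",
        "* FlatL2 or FlatIP. Well suited to small datasets. It may not be the fastest, but its memory consumption is not excessive.\n",
        "* LSH. Works well on small datasets and is recommended for vectors of up to 128 dimensions.\n",
        "* HNSW. Very fast, but requires a large amount of RAM.\n",
        "* IVF. Works well on large datasets without consuming much memory or compromising performance.\n",
        "\n",
        "More information about the different indexes available in Faiss can be found at this link: https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index\n",
        "\n"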
      ]
    },
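    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make concrete what a flat index does, the sketch below reproduces an exact L2 search in plain NumPy: the query is compared against every stored vector and the closest rows are kept. It is only an illustration, not the actual Faiss implementation; the function name `brute_force_l2_search` and the toy vectors are assumptions made for this example.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def brute_force_l2_search(stored, query, k=1):\n",
        "    # Hypothetical helper: exact nearest-neighbour search by scanning every vector.\n",
        "    # Returns squared L2 distances and row ids, closest first.\n",
        "    diffs = stored - query              # broadcast the query over all rows\n",
        "    dists = np.sum(diffs ** 2, axis=1)  # squared Euclidean distance per row\n",
        "    order = np.argsort(dists)[:k]       # ids of the k closest rows\n",
        "    return dists[order], order\n",
        "\n",
        "# Toy data: three 4-dimensional vectors standing in for cached embeddings.\n",
        "stored = np.array([[1.0, 0.0, 0.0, 0.0],\n",
        "                   [0.0, 1.0, 0.0, 0.0],\n",
        "                   [0.0, 0.0, 1.0, 0.0]], dtype='float32')\n",
        "query = np.array([0.9, 0.1, 0.0, 0.0], dtype='float32')\n",
        "\n",
        "distances, ids = brute_force_l2_search(stored, query, k=2)\n",
        "print(ids, distances)\n",
        "```\n",
        "\n",
        "The cost of this scan grows linearly with the number of stored vectors, which is why approximate indexes such as HNSW or IVF become attractive once the cache is large."
      ]
    },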
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "9poNBxbPl7xE"
      },
      "outputs": [],
      "source": [
        "def init_cache():\n",
        "  # Initialize Sentence Transformer model\n",
        "  encoder = SentenceTransformer('all-mpnet-base-v2')\n",
        "\n",
        "  # Exact L2 index sized to the encoder's embedding dimension (768 for this model)\n",
        "  index = faiss.IndexFlatL2(encoder.get_sentence_embedding_dimension())\n",
        "  if index.is_trained:\n",
        "    print('Index trained')\n",
        "\n",
        "  return index, encoder"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_uZzX60odo1U"
      },
      "source": [
        "The `retrieve_cache` function loads the .json file from disk, so the cache can be reused across sessions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "FDJJ86TSp5CO"
      },
      "outputs": [],
      "source": [
        "def retrieve_cache(json_file):\n",
        "  try:\n",
        "    with open(json_file, 'r') as file:\n",
        "      cache = json.load(file)\n",
        "  except FileNotFoundError:\n",
        "    cache = {'questions': [], 'embeddings': [], 'answers': [], 'response_text': []}\n",
        "\n",
        "  return cache"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3uO-12UIdtSD"
      },
      "source": [
        "The `store_cache` function saves the file containing the cache data to disk."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "jx1CiKOcwKGn"
      },
      "outputs": [],
      "source": [
        "def store_cache(json_file, cache):\n",
        "  with open(json_file, 'w') as file:\n",
        "    json.dump(cache, file)"
      ]
    },
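    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cache file is plain JSON, so everything placed in it must be JSON-serializable. That is why embeddings are converted with `tolist()` before being stored. The sketch below illustrates the round trip with a temporary file; the file name `demo_cache.json` and the dummy two-dimensional embedding are assumptions made for illustration.\n",
        "\n",
        "```python\n",
        "import json\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "# Same layout as the dictionary created by retrieve_cache.\n",
        "cache = {'questions': [], 'embeddings': [], 'answers': [], 'response_text': []}\n",
        "\n",
        "embedding = np.array([[0.25, 0.5]], dtype='float32')  # dummy 2-d embedding\n",
        "cache['questions'].append('How do vaccines work?')\n",
        "cache['embeddings'].append(embedding[0].tolist())     # ndarray -> list of floats\n",
        "cache['response_text'].append('placeholder answer')\n",
        "\n",
        "path = os.path.join(tempfile.mkdtemp(), 'demo_cache.json')  # hypothetical file\n",
        "with open(path, 'w') as file:\n",
        "    json.dump(cache, file)\n",
        "\n",
        "with open(path, 'r') as file:\n",
        "    restored = json.load(file)\n",
        "\n",
        "# Lists come back as lists; convert to float32 again before handing them to Faiss.\n",
        "restored_embeddings = np.array(restored['embeddings'], dtype='float32')\n",
        "print(restored_embeddings.shape)\n",
        "```\n",
        "\n",
        "Passing the NumPy array itself to `json.dump` would raise a `TypeError`, so the conversion to a list is not optional."
      ]
    },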
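    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t9AdmnhQd2E8"
      },
      "source": [
        "These functions are used inside the `semantic_cache` class, which contains the search function and its initializer.\n",
        "\n",
        "Although the `ask` function has a fair amount of code, its purpose is straightforward. It looks in the cache for the question closest to the one the user has just asked.\n",
        "\n",
        "It then checks whether that question falls within the specified threshold. If it does, the response is returned directly from the cache; otherwise, the `query_database` function is called to retrieve the data from ChromaDB.\n",
        "\n",
        "I used Euclidean distance instead of cosine distance, which is widely used for comparing vectors. The choice rests on the fact that Euclidean distance is the default metric in Faiss. Cosine distance can also be computed, but doing so adds complexity that may not contribute significantly to the final result. Note that `IndexFlatL2` reports squared Euclidean distances, so the threshold is compared against squared values.\n"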
      ]
    },
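    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a side note on how the two metrics relate: for vectors of unit length, the squared L2 distance equals 2 - 2 * cosine similarity, so a threshold on one can be translated into a threshold on the other. The sketch below checks this numerically with random vectors. It assumes normalized embeddings, which is worth verifying for whichever encoder is used; the vectors here are random stand-ins, not real sentence embeddings.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "\n",
        "# Two random vectors standing in for sentence embeddings (assumed values).\n",
        "a = rng.normal(size=768)\n",
        "b = rng.normal(size=768)\n",
        "\n",
        "# Normalize to unit length; the identity below only holds for unit vectors.\n",
        "a = a / np.linalg.norm(a)\n",
        "b = b / np.linalg.norm(b)\n",
        "\n",
        "squared_l2 = np.sum((a - b) ** 2)\n",
        "cosine_similarity = np.dot(a, b)\n",
        "\n",
        "# For unit vectors: squared_l2 == 2 - 2 * cosine_similarity\n",
        "print(squared_l2, 2 - 2 * cosine_similarity)\n",
        "\n",
        "# Under this assumption, a squared-L2 threshold of 0.35 corresponds to\n",
        "# a cosine similarity of at least 1 - 0.35 / 2 = 0.825.\n",
        "equivalent_cosine = 1 - 0.35 / 2\n",
        "```\n",
        "\n",
        "Seen this way, keeping the default L2 metric loses little when the embeddings are normalized, since both metrics rank neighbours in the same order."
      ]
    },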
    {
      "cell_type": "code",
      "execution_count": 51,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2024-02-29T17:36:51.31678Z",
          "iopub.status.busy": "2024-02-29T17:36:51.316449Z",
          "iopub.status.idle": "2024-02-29T17:36:55.197427Z",
          "shell.execute_reply": "2024-02-29T17:36:55.196616Z",
          "shell.execute_reply.started": "2024-02-29T17:36:51.316746Z"
        },
        "id": "t_HVtwww9TD2",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "class semantic_cache:\n",
        "  def __init__(self, json_file=\"cache_file.json\", threshold=0.35):\n",
        "      # Initialize Faiss index with Euclidean distance\n",
        "      self.index, self.encoder = init_cache()\n",
        "\n",
        "      # Set Euclidean distance threshold\n",
        "      # a distance of 0 means identical sentences\n",
        "      # We only return from cache sentences under this threshold\n",
        "      self.euclidean_threshold = threshold\n",
        "\n",
        "      self.json_file = json_file\n",
        "      self.cache = retrieve_cache(self.json_file)\n",
        "\n",
        "      # Reload embeddings saved in a previous session so that Faiss row ids\n",
        "      # stay aligned with the entries of the cached lists\n",
        "      if self.cache['embeddings']:\n",
        "          self.index.add(np.array(self.cache['embeddings'], dtype='float32'))\n",
        "\n",
        "  def ask(self, question: str) -> str:\n",
        "      # Method to retrieve an answer from the cache or generate a new one\n",
        "      start_time = time.time()\n",
        "      try:\n",
        "          # First we obtain the embeddings corresponding to the user question\n",
        "          embedding = self.encoder.encode([question])\n",
        "\n",
        "          # Search for the nearest neighbor in the index\n",
        "          # (IndexFlatL2 searches exhaustively, so nprobe does not apply)\n",
        "          D, I = self.index.search(embedding, 1)\n",
        "\n",
        "          # I[0][0] is -1 when the index is empty\n",
        "          if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:\n",
        "              row_id = int(I[0][0])\n",
        "\n",
        "              print('Answer recovered from Cache. ')\n",
        "              print(f'{D[0][0]:.3f} smaller than {self.euclidean_threshold}')\n",
        "              print(f'Found cache in row: {row_id} with score {D[0][0]:.3f}')\n",
        "              print('response_text: ' + self.cache['response_text'][row_id])\n",
        "\n",
        "              end_time = time.time()\n",
        "              elapsed_time = end_time - start_time\n",
        "              print(f\"Time taken: {elapsed_time:.3f} seconds\")\n",
        "              return self.cache['response_text'][row_id]\n",
        "\n",
        "          # Handle the case when there are not enough results\n",
        "          # or the Euclidean distance threshold is not met: ask ChromaDB.\n",
        "          answer = query_database([question], 1)\n",
        "          response_text = answer['documents'][0][0]\n",
        "\n",
        "          self.cache['questions'].append(question)\n",
        "          self.cache['embeddings'].append(embedding[0].tolist())\n",
        "          self.cache['answers'].append(answer)\n",
        "          self.cache['response_text'].append(response_text)\n",
        "\n",
        "          print('Answer recovered from ChromaDB. ')\n",
        "          print(f'response_text: {response_text}')\n",
        "\n",
        "          self.index.add(embedding)\n",
        "          store_cache(self.json_file, self.cache)\n",
        "          end_time = time.time()\n",
        "          elapsed_time = end_time - start_time\n",
        "          print(f\"Time taken: {elapsed_time:.3f} seconds\")\n",
        "\n",
        "          return response_text\n",
        "      except Exception as e:\n",
        "          raise RuntimeError(f\"Error during 'ask' method: {e}\") from e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UBWTqGM7i71N"
      },
      "source": [
        "### Testing the semantic_cache class\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 52,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JH8s8eUtCMIS",
        "outputId": "c613bbfc-9f84-4a96-cd39-45972e69c15b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Index trained\n"
          ]
        }
      ],
      "source": [
        "# Initialize the cache.\n",
        "cache = semantic_cache('4cache.json')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 53,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mKqKLfDe_8bC",
        "outputId": "8a92ed95-c822-4382-c6db-d9de289341af"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Answer recovered from ChromaDB. \n",
            "response_text: Summary : Shots may hurt a little, but the diseases they can prevent are a lot worse. Some are even life-threatening. Immunization shots, or vaccinations, are essential. They protect against things like measles, mumps, rubella, hepatitis B, polio, tetanus, diphtheria, and pertussis (whooping cough). Immunizations are important for adults as well as children.    Your immune system helps your body fight germs by producing substances to combat them. Once it does, the immune system \"remembers\" the germ and can fight it again. Vaccines contain germs that have been killed or weakened. When given to a healthy person, the vaccine triggers the immune system to respond and thus build immunity.     Before vaccines, people became immune only by actually getting a disease and surviving it. Immunizations are an easier and less risky way to become immune.     NIH: National Institute of Allergy and Infectious Diseases\n",
            "Time taken: 0.057 seconds\n"
          ]
        }
      ],
      "source": [
        "results = cache.ask(\"How do vaccines work?\")"
      ]
    },
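    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dP7H6TypknLN"
      },
      "source": [
        "As expected, this response was obtained from ChromaDB. The class then stores it in the cache.\n",
        "\n",
        "Now, if we send a completely different question, the response should also be retrieved from ChromaDB. The previously stored question is so different from the current one that their Euclidean distance exceeds the specified threshold.\n"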
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 54,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "execution": {
          "iopub.execute_input": "2024-02-29T17:37:15.335593Z",
          "iopub.status.busy": "2024-02-29T17:37:15.335288Z",
          "iopub.status.idle": "2024-02-29T17:37:17.320691Z",
          "shell.execute_reply": "2024-02-29T17:37:17.319671Z",
          "shell.execute_reply.started": "2024-02-29T17:37:15.335566Z"
        },
        "id": "CvJykqVf9TD2",
        "outputId": "7137919e-e417-47b3-a638-18026b3edfe6",
        "trusted": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Answer recovered from ChromaDB. \n",
            "response_text: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused by an auto-immune reaction to the bacterium that interferes with the normal function of a part of the brain (the basal ganglia) that controls motor movements. Due to better sanitary conditions and the use of antibiotics to treat streptococcal infections, rheumatic fever, and consequently SD, are rare in North America and Europe. The disease can still be found in developing nations.\n",
            "Time taken: 0.082 seconds\n"
          ]
        }
      ],
      "source": [
        "\n",
        "results = cache.ask(\"Explain briefly what is a Sydenham chorea\")"
      ]
    },
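    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8aPWvU64lxOU"
      },
      "source": [
        "Perfect, the semantic cache system is behaving as expected.\n",
        "\n",
        "Let's continue testing it with a question very similar to the one we have just asked.\n",
        "\n",
        "In this case, the response should come directly from the cache, without any access to the ChromaDB database.\n"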
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 55,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "execution": {
          "iopub.execute_input": "2024-02-29T17:37:17.328926Z",
          "iopub.status.busy": "2024-02-29T17:37:17.32865Z",
          "iopub.status.idle": "2024-02-29T17:37:17.463363Z",
          "shell.execute_reply": "2024-02-29T17:37:17.462397Z",
          "shell.execute_reply.started": "2024-02-29T17:37:17.328902Z"
        },
        "id": "9_5IcGB-9TD2",
        "outputId": "13563a7d-01f7-47d1-c345-6ad128f303c3",
        "trusted": true
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Answer recovered from Cache. \n",
            "0.028 smaller than 0.35\n",
            "Found cache in row: 1 with score 0.028\n",
            "response_text: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused by an auto-immune reaction to the bacterium that interferes with the normal function of a part of the brain (the basal ganglia) that controls motor movements. Due to better sanitary conditions and the use of antibiotics to treat streptococcal infections, rheumatic fever, and consequently SD, are rare in North America and Europe. The disease can still be found in developing nations.\n",
            "Time taken: 0.019 seconds\n"
          ]
        }
      ],
      "source": [
        "results = cache.ask(\"Briefly explain me what is a Sydenham chorea.\")"
      ]
    },
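    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M4H8RoXFqdwE"
      },
      "source": [
        "The two questions are so similar that their Euclidean distance is tiny, almost as if they were identical.\n",
        "\n",
        "Now let's try another question, this time a bit more different, and observe how the system behaves."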
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 56,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ysj5P_MBCqju",
        "outputId": "d4639f73-dc7e-4c25-93ba-2a8c66dc7c61"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Answer recovered from Cache. \n",
            "0.228 smaller than 0.35\n",
            "Found cache in row: 1 with score 0.228\n",
            "response_text: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused by an auto-immune reaction to the bacterium that interferes with the normal function of a part of the brain (the basal ganglia) that controls motor movements. Due to better sanitary conditions and the use of antibiotics to treat streptococcal infections, rheumatic fever, and consequently SD, are rare in North America and Europe. The disease can still be found in developing nations.\n",
            "Time taken: 0.016 seconds\n"
          ]
        }
      ],
      "source": [
        "question_def = \"Write in 20 words what is a Sydenham chorea.\"\n",
        "results = cache.ask(question_def)"
      ]
    },
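    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MFzXsQwB9TD3"
      },
      "source": [
        "We can see that the Euclidean distance has increased, but it is still within the specified threshold. The response therefore continues to come directly from the cache."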
      ]
    },
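    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ot3wrq0p9TD3"
      },
      "source": [
        "# Loading the model and creating the prompt\n",
        "\n",
        "Time to use the **transformers** library, the best-known library from [Hugging Face](https://huggingface.co/) for working with language models.\n",
        "\n",
        "We will import:\n",
        "* **AutoTokenizer**: a utility class for tokenizing text inputs, compatible with a wide range of pre-trained language models.\n",
        "* **AutoModelForCausalLM**: an interface to pre-trained language models designed for text generation with causal language modeling (for example, GPT models), such as the model used in this notebook, [Gemma-2b-it](https://huggingface.co/google/gemma-2b-it).\n",
        "\n",
        "Feel free to test [different models](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending); look for NLP models trained for text generation.\n"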
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2024-02-29T17:40:32.797669Z",
          "iopub.status.busy": "2024-02-29T17:40:32.797334Z",
          "iopub.status.idle": "2024-02-29T17:40:44.152114Z",
          "shell.execute_reply": "2024-02-29T17:40:44.151056Z",
          "shell.execute_reply.started": "2024-02-29T17:40:32.797635Z"
        },
        "id": "tdxiKqjT9TD3",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "!pip install torch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 25,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2024-02-29T17:40:44.15434Z",
          "iopub.status.busy": "2024-02-29T17:40:44.153914Z",
          "iopub.status.idle": "2024-02-29T17:40:44.160144Z",
          "shell.execute_reply": "2024-02-29T17:40:44.159154Z",
          "shell.execute_reply.started": "2024-02-29T17:40:44.154292Z"
        },
        "id": "pIDMTCnH9TD7",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch import cuda\n",
        "# On Apple Silicon Macs the device must be 'mps'\n",
        "# device = torch.device('mps')  # to use with Apple Silicon\n",
        "device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {
          "iopub.execute_input": "2024-02-29T17:41:25.628804Z",
          "iopub.status.busy": "2024-02-29T17:41:25.628412Z",
          "iopub.status.idle": "2024-02-29T17:41:30.202141Z",
          "shell.execute_reply": "2024-02-29T17:41:30.200774Z",
          "shell.execute_reply.started": "2024-02-29T17:41:25.628766Z"
        },
        "id": "CU2T4lp-9TD7",
        "trusted": true
      },
      "outputs": [],
      "source": [
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "model_id = \"google/gemma-2b-it\"\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
        "# Load the model on the device selected above instead of hard-coding \"cuda\"\n",
        "model = AutoModelForCausalLM.from_pretrained(model_id,\n",
        "                                             device_map=device,\n",
        "                                             torch_dtype=torch.bfloat16)"
      ]
    },
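    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GzHuFrAX9TD7"
      },
      "source": [
        "## Creating the extended prompt\n",
        "\n",
        "To create the prompt, we use the result obtained from the `semantic_cache` class together with the question asked by the user.\n",
        "\n",
        "The prompt has two parts: the **relevant context**, which is the information recovered from the database, and the **user's question**.\n",
        "\n",
        "We only need to join the two parts to create the prompt and then send it to the model."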
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 44,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 209
        },
        "id": "TdjbfAHhFuhS",
        "outputId": "4090da66-328e-478e-c2d7-1957597f8786"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "\"Relevant context: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused by an auto-immune reaction to the bacterium that interferes with the normal function of a part of the brain (the basal ganglia) that controls motor movements. Due to better sanitary conditions and the use of antibiotics to treat streptococcal infections, rheumatic fever, and consequently SD, are rare in North America and Europe. The disease can still be found in developing nations.\\n\\n The user's question: Write in 20 words what is a Sydenham chorea.\""
            ]
          },
          "execution_count": 44,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "prompt_template = f\"Relevant context: {results}\\n\\n The user's question: {question_def}\"\n",
        "prompt_template"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 45,
      "metadata": {
        "id": "DmYAcXEEECnz"
      },
      "outputs": [],
      "source": [
        "input_ids = tokenizer(prompt_template, return_tensors=\"pt\").to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S-QXeuJ09TD8"
      },
      "source": [
        "All that remains is to send the prompt to the model and wait for its response!\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 46,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lheL8vHpEMDD",
        "outputId": "b646d648-b88d-4a29-ab30-427d00296255"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<bos>Relevant context: Sydenham chorea (SD) is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS), the bacterium that causes rheumatic fever. SD is characterized by rapid, irregular, and aimless involuntary movements of the arms and legs, trunk, and facial muscles. It affects girls more often than boys and typically occurs between 5 and 15 years of age. Some children will have a sore throat several weeks before the symptoms begin, but the disorder can also strike up to 6 months after the fever or infection has cleared. Symptoms can appear gradually or all at once, and also may include uncoordinated movements, muscular weakness, stumbling and falling, slurred speech, difficulty concentrating and writing, and emotional instability. The symptoms of SD can vary from a halting gait and slight grimacing to involuntary movements that are frequent and severe enough to be incapacitating. The random, writhing movements of chorea are caused by an auto-immune reaction to the bacterium that interferes with the normal function of a part of the brain (the basal ganglia) that controls motor movements. Due to better sanitary conditions and the use of antibiotics to treat streptococcal infections, rheumatic fever, and consequently SD, are rare in North America and Europe. The disease can still be found in developing nations.\n",
            "\n",
            " The user's question: Write in 20 words what is a Sydenham chorea.\n",
            "\n",
            "Sure, here is a 20-word answer:\n",
            "\n",
            "Sydenham chorea is a neurological disorder of childhood resulting from infection via Group A beta-hemolytic streptococcus (GABHS).<eos>\n"
          ]
        }
      ],
      "source": [
        "outputs = model.generate(**input_ids,\n",
        "                         max_new_tokens=256)\n",
        "print(tokenizer.decode(outputs[0]))"
      ]
    },
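    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {
          "iopub.execute_input": "2023-07-12T22:01:56.993351Z",
          "iopub.status.busy": "2023-07-12T22:01:56.992775Z",
          "iopub.status.idle": "2023-07-12T22:01:57.001309Z",
          "shell.execute_reply": "2023-07-12T22:01:56.999431Z",
          "shell.execute_reply.started": "2023-07-12T22:01:56.993305Z"
        },
        "id": "Uo7lGXBV9TD8"
      },
      "source": [
        "# Conclusion\n",
        "\n",
        "Retrieving data from the cache cut retrieval time by more than 50% compared with querying ChromaDB: in the runs above, cache hits took under 0.02 seconds, against roughly 0.06 to 0.08 seconds for ChromaDB queries. In larger projects this gap tends to widen, and performance improvements of 90-95% are possible.\n",
        "\n",
        "We have very little data in Chroma and only one instance of the cache class. Typically, the data behind the cache system is much larger, and it may come from a variety of sources rather than from a single query to a vector database.\n",
        "\n",
        "It is common to have several instances of the cache class, usually organized by user type, because questions tend to repeat more often among users who share common characteristics.\n",
        "\n",
        "In summary, we have built a very simple RAG system and enhanced it with a semantic cache layer placed between the user's question and the retrieval of the information needed to build the enriched prompt.\n"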
      ]
    }
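    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the reduction in retrieval time can be worked out from the timings printed earlier in this notebook. The sketch below is only back-of-the-envelope arithmetic: four measurements are far too few to be conclusive, and the ChromaDB timings also include writing the cache file to disk.\n",
        "\n",
        "```python\n",
        "# Timings in seconds copied from the outputs printed earlier in this notebook.\n",
        "chroma_times = [0.057, 0.082]   # the two answers recovered from ChromaDB\n",
        "cache_times = [0.019, 0.016]    # the two answers recovered from the cache\n",
        "\n",
        "mean_chroma = sum(chroma_times) / len(chroma_times)\n",
        "mean_cache = sum(cache_times) / len(cache_times)\n",
        "\n",
        "# Relative reduction in retrieval time when the cache is hit.\n",
        "reduction = 1 - mean_cache / mean_chroma\n",
        "print(f'Retrieval time reduced by {reduction:.0%} on average')\n",
        "```\n",
        "\n",
        "Timing many repeated queries, for example with `time.perf_counter`, would give a more reliable figure."
      ]
    }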
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "machine_shape": "hm",
      "provenance": []
    },
    "kaggle": {
      "accelerator": "gpu",
      "dataSources": [
        {
          "datasetId": 3496946,
          "sourceId": 6104553,
          "sourceType": "datasetVersion"
        }
      ],
      "dockerImageVersionId": 30527,
      "isGpuEnabled": true,
      "isInternetEnabled": true,
      "language": "python",
      "sourceType": "notebook"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
